{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python361264bitblockscondae1e5e96102154286a7dee4f769994df8",
   "display_name": "Python 3.6.12 64-bit ('blocks': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "80b57d3b473cded757c1c09da496d17d13901a1bde76b141e0b4105d5712fb37"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns   \n",
    "from matplotlib import pyplot as plt \n",
    "import json \n",
    "import pathlib\n",
    "import re\n",
    "\n",
    "import torch \n",
    "from spacy.tokenizer import Tokenizer\n",
    "from spacy.lang.en import English\n",
    "\n",
    "import sys \n",
    "sys.path.insert(0, \"/home/estengel/real_good_robot\")\n",
    "\n",
    "from transformer import TransformerEncoder\n",
    "from language_embedders import RandomEmbedder \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def map_k(k):\n",
    "    return re.sub(\"language_embedder\\.\",\"\",k)\n",
    "\n",
    "def load_model(path):\n",
    "    nlp = English()\n",
    "    tokenizer = Tokenizer(nlp.vocab)\n",
    "    vocab = json.load(open(path.joinpath(\"vocab.json\")))\n",
    "    embedder = RandomEmbedder(tokenizer, vocab, 300, trainable=True)\n",
    "    state_dict = torch.load(path.joinpath(\"best.th\"), map_location='cpu')\n",
    "    state_dict = {k:v for k,v in state_dict.items() if \"embedder\" in k}\n",
    "    state_dict = {map_k(k):v for k,v in state_dict.items()}\n",
    "    embedder.load_state_dict(state_dict, strict=True)\n",
    "    embedder.device = \"cpu\"\n",
    "    return embedder \n",
    "\n",
    "\n",
    "def get_color_vectors(model, colors): \n",
    "    embeddings = model.forward(colors)\n",
    "    return embeddings.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = load_model(pathlib.Path(\"/srv/local1/estengel/models/gr_train_augment_only_depthmaps_fixed/\"))\n",
    "colors = ['red','blue','green','yellow', 'purple', 'gray', 'orange','brown']\n",
    "\n",
    "embeddings = get_color_vectors(model, colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "output_type": "error",
     "ename": "ValueError",
     "evalue": "n_iter should be at least 250",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-51-d8ede7abcc80>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mdisplay_tsne_scatterplot_2D\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membeddings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-51-d8ede7abcc80>\u001b[0m in \u001b[0;36mdisplay_tsne_scatterplot_2D\u001b[0;34m(word_vectors, colors)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0;31m# three_dim = TSNE(n_components = 3, random_state=0, perplexity = perplexity, learning_rate = learning_rate, n_iter = iteration).fit_transform(word_vectors)[:,:3]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m     \u001b[0mtwo_dim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTSNE\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_components\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mperplexity\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_iter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_transform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword_vectors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtwo_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/blocks/lib/python3.6/site-packages/sklearn/manifold/_t_sne.py\u001b[0m in \u001b[0;36mfit_transform\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m    889\u001b[0m             \u001b[0mEmbedding\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mtraining\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlow\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mdimensional\u001b[0m \u001b[0mspace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    890\u001b[0m         \"\"\"\n\u001b[0;32m--> 891\u001b[0;31m         \u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    892\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0membedding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    893\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/blocks/lib/python3.6/site-packages/sklearn/manifold/_t_sne.py\u001b[0m in \u001b[0;36m_fit\u001b[0;34m(self, X, skip_num_points)\u001b[0m\n\u001b[1;32m    700\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    701\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_iter\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m250\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 702\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"n_iter should be at least 250\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    703\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    704\u001b[0m         \u001b[0mn_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: n_iter should be at least 250"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "def display_tsne_scatterplot_2D(word_vectors, colors):\n",
    "    \n",
    "    # three_dim = TSNE(n_components = 3, random_state=0, perplexity = perplexity, learning_rate = learning_rate, n_iter = iteration).fit_transform(word_vectors)[:,:3]\n",
    "    two_dim = TSNE(n_components = 2, random_state=0, perplexity = 0, learning_rate = 0, n_iter = 500).fit_transform(word_vectors)[:,:2]\n",
    "\n",
    "    print(two_dim)\n",
    "    \n",
    "\n",
    "\n",
    "display_tsne_scatterplot_2D(embeddings, colors)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}